# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Routing utilities for the Mixture-of-Experts (MoE) module."""

# Public API of this module: only the top-k routing entry point is exported.
__all__ = ["topk_routing_with_score_function"]

from typing import Optional

from mindspore import Tensor, mint
import mindspore.common.dtype as mstype
from mindspore.ops.auto_generate import FusedAddTopKDiv


def softmax_score_function(x):
    """Turn router logits into softmax probabilities over the last axis.

    The softmax is evaluated with ``dtype=mstype.float32``.
    """
    probabilities = mint.softmax(x, dim=-1, dtype=mstype.float32)
    return probabilities


def sigmoid_score_function(x):
    """Score each router logit independently with the logistic sigmoid.

    Unlike the softmax score function, the per-expert scores are not
    normalized against each other.
    """
    gate_scores = mint.sigmoid(x)
    return gate_scores


# Registry of router score functions, keyed by the names accepted by the
# `score_function` argument of `topk_routing_with_score_function`.
# Register any new score function here so it can be selected by name.
SCORE_FUNCTION_MAP = {"softmax": softmax_score_function, "sigmoid": sigmoid_score_function}


def group_limited_topk(
        scores: Tensor,
        topk: int,
        num_experts: int,
        num_groups: int,
        group_topk: int,
):
    """Perform top-k expert selection restricted to the best ``group_topk`` expert groups.

    The experts are split into ``num_groups`` contiguous groups of
    ``num_experts // num_groups`` experts. Each group is ranked by the sum of its
    ``topk // group_topk`` largest scores, and the ``group_topk`` best groups are kept.
    The final ``topk`` experts are then chosen only among experts in the kept groups.

    Args:
        scores (Tensor): Routing scores, expected shape [num_tokens, num_experts]. These may
            be score-function outputs or score-function outputs plus an expert bias, so
            individual entries can be negative.
        topk (int): The number of experts to select for each token. Should not exceed
            ``group_topk * (num_experts // num_groups)``; otherwise masked (-inf) entries
            would be selected.
        num_experts (int): The number of experts; assumed divisible by ``num_groups``.
        num_groups (int): Number of groups for routed experts.
        group_topk (int): Number of groups selected for each token.

    Returns:
        Tuple[Tensor, Tensor]: ``(probs, top_indices)``, each of shape [num_tokens, topk]:
        the selected scores and the indices of the selected experts.
    """
    experts_per_group = num_experts // num_groups

    # Rank groups by the sum of their top-(topk // group_topk) scores.
    group_scores = (
        scores.reshape(-1, num_groups, experts_per_group).topk(topk // group_topk, dim=-1)[0].sum(dim=-1)
    )
    group_idx = mint.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1]
    group_mask = mint.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)

    # Expand the per-group mask [num_tokens, num_groups] into a per-expert mask
    # [num_tokens, num_experts].
    score_mask = group_mask.unsqueeze(-1)
    score_mask = score_mask.repeat_interleave(repeats=experts_per_group, dim=-1)
    score_mask = score_mask.reshape(-1, num_experts)

    # Fill experts outside the selected groups with -inf, not 0.0. Biased scores can be
    # negative, and a 0.0 fill would let unselected experts outrank valid ones.
    masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
    probs, top_indices = mint.topk(masked_scores, k=topk, dim=-1)

    return probs, top_indices


def topk_routing_with_score_function(
        logits: Tensor,
        topk: int,
        num_experts: int,
        num_groups: Optional[int] = None,
        group_topk: Optional[int] = None,
        scaling_factor: Optional[float] = None,
        score_function: str = "softmax",
        expert_bias: Optional[Tensor] = None,
        norm_topk_prob: Optional[bool] = True,
        fused: bool = False,
):
    """Compute top-k routing weights and selected expert indices from router logits.

    Non-fused path:
        1. Apply ``score_function`` to ``logits`` to obtain per-expert scores.
        2. If ``expert_bias`` is given, add it to the scores only for expert selection.
           The returned weights are gathered from the un-biased scores.
        3. Select ``topk`` experts per token. If ``group_topk`` is set, selection is
           restricted to ``group_topk`` of ``num_groups`` expert groups.
        4. Optionally renormalize the selected weights and multiply them by
           ``scaling_factor``.

    Fused path (``fused=True`` and ``group_topk`` set): delegates to the
    ``FusedAddTopKDiv`` operator and returns the operator's outputs unchanged.
    If ``fused=True`` but ``group_topk`` is not set, the non-fused path is used.

    Args:
        logits (Tensor): Router logits with shape [num_tokens, num_experts].
        topk (int): The number of experts to select for each token.
        num_experts (int): The number of experts.
        num_groups (int, optional): Number of groups for routed experts.
        group_topk (int, optional): Number of selected groups for each token. Enables
            group-limited routing when truthy.
        scaling_factor (float, optional): Multiplier applied to the final weights.
            In the non-fused path a falsy value (None or 0) disables scaling. In the
            fused path None is passed to the operator as 1.0.
        score_function (str): Key of ``SCORE_FUNCTION_MAP`` ("softmax" or "sigmoid").
            Not used by the fused path.
        expert_bias (Tensor, optional): Per-expert bias of shape [num_experts]. It is
            added to the scores for expert selection only.
        norm_topk_prob (bool, optional): Whether to renormalize the selected weights so
            they sum to 1 per token. No effect when ``topk == 1`` (non-fused path).
        fused (bool): Whether to use the fused operator (requires ``group_topk``).

    Returns:
        Tuple[Tensor, Tensor] for the non-fused path:
            - expert_weight (Tensor): Shape [num_tokens, topk], the routing weight of
              each selected expert.
            - routing_map (Tensor): int32 tensor of shape [num_tokens, topk], the indices
              of the selected experts.
        For the fused path, whatever ``FusedAddTopKDiv`` returns.

    Raises:
        ValueError: If ``score_function`` is not a key of ``SCORE_FUNCTION_MAP``
            (non-fused path only).
    """
    if fused and group_topk:
        fused_add_topk_div = FusedAddTopKDiv()
        # The fused operator requires the bias to be a tensor; substitute zeros when none is given.
        if expert_bias is None:
            expert_bias = mint.zeros((num_experts,), dtype=mstype.float32)
        # In the non-fused path, scaling_factor=None means "no scaling". Forward the
        # equivalent 1.0 rather than None so the operator receives a float scale.
        scale = 1.0 if scaling_factor is None else scaling_factor
        # NOTE(review): `score_function` and `norm_topk_prob` are not forwarded. The 7th and
        # 8th operator arguments are hard-coded (`0`, `True`); presumably these are the
        # activation type and the normalization flag. Confirm this matches the non-fused
        # behavior for the configurations that enable `fused`.
        return fused_add_topk_div(
            logits,
            expert_bias,
            num_groups,
            group_topk,
            topk // group_topk,
            topk,
            0,
            True,
            scale,
        )

    score_func = SCORE_FUNCTION_MAP.get(score_function)
    if score_func is None:
        raise ValueError(f"Invalid score_function: {score_function}. "
                         f"Supported functions are: {list(SCORE_FUNCTION_MAP.keys())}")

    def select_topk(routing_scores):
        """Pick `topk` experts per token, group-limited when `group_topk` is set."""
        if group_topk:
            return group_limited_topk(
                scores=routing_scores,
                topk=topk,
                num_experts=num_experts,
                num_groups=num_groups,
                group_topk=group_topk,
            )
        return mint.topk(routing_scores, k=topk, dim=-1)

    scores = score_func(logits)
    if expert_bias is not None:
        # The bias only decides which experts are chosen; the weights come from the
        # un-biased scores.
        scores_for_routing = mint.add(scores, expert_bias.unsqueeze(0))
        _, top_indices = select_topk(scores_for_routing)
        expert_weight = mint.gather(scores, 1, top_indices)
    else:
        expert_weight, top_indices = select_topk(scores)

    if norm_topk_prob and topk > 1:
        # Renormalize the selected weights to sum to 1; the epsilon guards against a zero denominator.
        expert_weight = mint.div(expert_weight, mint.add(mint.sum(expert_weight, -1, True), 1e-20))
    if scaling_factor:
        expert_weight = mint.mul(expert_weight, scaling_factor)

    routing_map = top_indices.astype(mstype.int32)

    return expert_weight, routing_map
